Fix issues with split and split_dims #1828
base: main
Conversation
pytensor/tensor/basic.py (outdated)

    ):
        # All elements already have the right number of dimensions, so we
        # can just join them directly.
        return join(0, *x)
This isn't equivalent to stack below?
No, because stack adds a dimension. This was causing a bug in split_dims where we explicitly ask for ndim=1, passing a sequence of 1d tensors, but then get back a 2d tensor.
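For illustration (my own sketch, not part of the PR diff), the dimension difference being described, with symbolic 1d inputs:

```python
import pytensor.tensor as pt

a = pt.vector("a")
b = pt.vector("b")

# join concatenates along an existing axis: the result stays 1d
joined = pt.join(0, a, b)
print(joined.ndim)   # 1

# stack inserts a new leading axis: a sequence of 1d tensors becomes 2d
stacked = pt.stack([a, b])
print(stacked.ndim)  # 2
```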
Hmm, so my understanding is that this function is supposed to do what np.array(x) would do. I think the ndim is more of an assert: it should fail when the output of np.array (in our case the symbolic equivalent) would yield something different. In that sense join is never valid, since it keeps the same number of dimensions.
I want to revert and check if I'm missing something with the test that was failing.
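As a concrete reference point for the np.array analogy (my own example, not from the PR):

```python
import numpy as np

# np.array on a sequence of 1d arrays always produces a 2d array;
# there is no keyword to keep the result 1d, which is why ndim reads
# more like an assertion on the converted result than a conversion knob.
out = np.array([np.arange(3), np.arange(3)])
print(out.ndim)  # 2
```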
Sure. From my perspective the biggest issue is that as_tensor_variable(..., ndim=1) isn't idempotent -- sequential calls on the same input keep mutating the same graph. This is happening because of stack.
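A sketch of the kind of idempotency check being asked for here (a hypothetical test, not from the PR):

```python
import pytensor.tensor as pt

def test_as_tensor_variable_idempotent():
    x = pt.vector("x")
    y1 = pt.as_tensor_variable(x, ndim=1)
    y2 = pt.as_tensor_variable(y1, ndim=1)
    # Converting an already-converted variable should be a no-op:
    # the same variable comes back and no extra nodes are added to the graph.
    assert y1 is x
    assert y2 is y1
```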
That's odd because if it's already a single tensor variable (and not a list with one in it) it shouldn't do anything
Yeah that first one seems wrong.
Even if we fix it, I think our check for "sequence" in split_dims (or wherever the problem was) should be more like if isinstance(x, Sequence) or (isinstance(x, TensorVariable) and x.ndim == 1).
1d numpy arrays should also be valid, but maybe those pass the Sequence instance check.
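A sketch of how that check could look (the helper name is mine; note that numpy arrays are not registered as collections.abc.Sequence, so they need their own branch):

```python
from collections.abc import Sequence

import numpy as np
from pytensor.tensor.variable import TensorVariable

def looks_like_shape_sequence(x) -> bool:
    """Hypothetical helper: treat python sequences, 1d TensorVariables,
    and 1d numpy arrays as sequence-like shape inputs."""
    if isinstance(x, (TensorVariable, np.ndarray)):
        return x.ndim == 1
    # np.ndarray does not pass this check, hence the explicit branch above
    return isinstance(x, Sequence)
```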
Maybe we should remove the ndim argument altogether? numpy doesn't have it and I don't think we need it.
I thought it was just used for validation but it seems to affect non-raising outcomes
> Maybe we should remove the ndim argument altogether? numpy doesn't have it and I don't think we need it.
> I thought it was just used for validation but it seems to affect non-raising outcomes
I'm +1 for removing it. I never knew it existed, and it seems like it's overloading the function.
If I had to guess though, it's exactly for this situation. We have an argument with type int | Variable | tuple[int | Variable]. The Variable, though, can be either a scalar or an array, so really the typing is something like int | Variable[ndim=0] | Variable[ndim=1] | tuple[int | Variable[ndim=0]]. When we do the if not isinstance(shape, tuple): shape = (shape,) we're ignoring the Variable[ndim=1] case. Calling as_tensor_variable(tuple[Variable[ndim=0]]) -> Variable[ndim=1] makes sense to me, and matches the numpy behavior. In this case we're counting on the ndim=1 argument to guard against the case of as_tensor_variable(tuple[Variable[ndim=1]]) -> Variable[ndim=2].
Typing all this out, it seems like an abuse of the as_tensor_variable function.
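A hedged sketch of the normalization being described (illustrative names, not the actual split_dims code):

```python
import pytensor.tensor as pt
from pytensor.tensor.variable import TensorVariable

def normalize_shape(shape):
    # Accepts int | Variable[ndim=0] | Variable[ndim=1] | tuple[int | Variable[ndim=0]].
    # A 1d Variable is already sequence-like, so wrapping it in a tuple (the
    # `shape = (shape,)` pattern above) would hand as_tensor_variable a
    # tuple[Variable[ndim=1]] and produce a 2d result.
    if isinstance(shape, TensorVariable) and shape.ndim == 1:
        return shape
    if not isinstance(shape, tuple):
        shape = (shape,)
    return pt.as_tensor_variable(shape, ndim=1)
```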
Yeah agreed. Would be really nice to be able to have those TensorVariable[ndim=0] types btw. Need to nerdsnipe some type hint lovers
Force-pushed 4f38402 to 579566d
I reverted the changes to [...]. Something else I noticed was that we're passing [...]
No, better not to cast variables in the node, but raise like before. That's what shape ops always do. If a user passes a float as a shape argument it's likely a bug and this would mask it.
Someday I will merge a PR
pytensor/tensor/reshape.py (outdated)

    )

    -    if not shape:
    +    if empty_shape:
What about just shape.type.shape == (0,) for the variable case? Also, if you standardize as_tensor_variable, you don't need the variable vs non-variable case.
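For reference, a small example of the static-shape check being suggested (my own snippet):

```python
import numpy as np
import pytensor.tensor as pt

shape = pt.as_tensor_variable(np.array([], dtype="int64"))

# For a 1d shape variable whose length is statically known to be zero,
# the type records it, so the "empty shape" case can be detected symbolically.
print(shape.type.shape)          # (0,)
print(shape.type.shape == (0,))  # True
```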
But also do we need the special squeeze branch or would the Op do the right thing anyway?
Tests pass without it (as long as I adjust the existing test_split_size_zero_shape test to pass dtype int to the shape argument), so I guess not.
I'm happy with the PR. I'll fix the git history and merge.
Related to #1806, #1827

- Fix bug when passing a simple Tensor shape to split_dims
- Change grad_undefined -> grad_disconnected for split_sizes in SplitOp (see #1827 for more context)